/**
 * @file convolution_test.cpp
 * @author Marcus Edel
 *
 * Tests for the max and mean pooling rules.
 */
#include <mlpack/core.hpp>

#include <mlpack/methods/ann/pooling_rules/max_pooling.hpp>
#include <mlpack/methods/ann/pooling_rules/mean_pooling.hpp>

#include <boost/test/unit_test.hpp>
#include "old_boost_test_definitions.hpp"

using namespace mlpack;
using namespace mlpack::ann;

BOOST_AUTO_TEST_SUITE(PoolingTest);

/**
 * Test the max pooling rule.
 */
BOOST_AUTO_TEST_CASE(MaxPoolingTest)
{
  // The data was generated by magic(6) in MATLAB.
  arma::mat input, output;
  input << 35 << 1 << 6 << 26 << 19 << 24 << arma::endr
        << 3 << 32 << 7 << 21 << 23 << 25 << arma::endr
        << 31 << 9 << 2 << 22 << 27 << 20 << arma::endr
        << 8 << 28 << 33 << 17 << 10 << 15 << arma::endr
        << 30 << 5 << 34 << 12 << 14 << 16 << arma::endr
        << 4 << 36 << 29 << 13 << 18 << 11;

  // Expected output of the generated 6 x 6 matrix.
  const double poolingOutput = 36;

  MaxPooling poolingRule;

  // Test the pooling function.
  BOOST_REQUIRE_EQUAL(poolingRule.Pooling(input), poolingOutput);

  // Test the unpooling function.
  poolingRule.Unpooling(input, input.max(), output);
  BOOST_REQUIRE_EQUAL(arma::accu(output), input.max());
}

/**
 * Test the mean pooling rule.
 */
BOOST_AUTO_TEST_CASE(MeanPoolingTest)
{
  // The data was generated by magic(6) in MATLAB.
  arma::mat input, output;
  input << 35 << 1 << 6 << 26 << 19 << 24 << arma::endr
        << 3 << 32 << 7 << 21 << 23 << 25 << arma::endr
        << 31 << 9 << 2 << 22 << 27 << 20 << arma::endr
        << 8 << 28 << 33 << 17 << 10 << 15 << arma::endr
        << 30 << 5 << 34 << 12 << 14 << 16 << arma::endr
        << 4 << 36 << 29 << 13 << 18 << 11;

  // Expected output of the generated 6 x 6 matrix.
  const double poolingOutput = 18.5;

  MeanPooling poolingRule;

  // Test the pooling function.
  BOOST_REQUIRE_EQUAL(poolingRule.Pooling(input), poolingOutput);

  // Test the unpooling function.
  poolingRule.Unpooling(input, input.max(), output);
  bool b = arma::all(arma::vectorise(output) == (input.max() / input.n_elem));
  BOOST_REQUIRE_EQUAL(b, true);
}

BOOST_AUTO_TEST_SUITE_END();
